{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DeepSpeech Conversion\n",
    "\n",
    "This example demonstrates the workflow to download a publicly available \n",
    "TensorFlow model and convert it to Core ML format using the tfcoreml converter.\n",
    "\n",
    "We use an open-source implementation of the DeepSpeech model (https://arxiv.org/abs/1412.5567) \n",
    "provided by Mozilla: https://github.com/mozilla/DeepSpeech.\n",
    "\n",
    "Note that this notebook was tested using the following dependencies: \n",
    "\n",
    "```\n",
    "tensorflow==1.14.0\n",
    "coremltools==3.0\n",
    "```\n",
    "\n",
    "It will **NOT** work on TensorFlow 2.0+ because the `tf.contrib` module was deprecated and then removed in TensorFlow 2.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tfcoreml\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download TensorFlow implementation of DeepSpeech model from https://github.com/mozilla/DeepSpeech\n",
    "# wget https://github.com/mozilla/DeepSpeech/releases/download/v0.4.1/deepspeech-0.4.1-models.tar.gz\n",
    "# tar xvfz deepspeech-0.4.1-models.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 assert nodes deleted\n",
      "['h2:0', 'lstm_fused_cell/kernel:0', 'lstm_fused_cell/kernel/read:0', 'lstm_fused_cell/SequenceMask/Const_1:0', 'lstm_fused_cell/ExpandDims_1/dim:0', 'lstm_fused_cell/ExpandDims_2/dim:0', 'Reshape/shape:0', 'lstm_fused_cell/SequenceMask/Const:0', 'h6:0', 'lstm_fused_cell/SequenceMask/Range:0', 'zeros/shape_as_tensor:0', 'Minimum_2/y:0', 'h5/read:0', 'lstm_fused_cell/ExpandDims/dim:0', 'lstm_fused_cell/range_1/delta:0', 'b1/read:0', 'lstm_fused_cell/range/delta:0', 'Minimum/y:0', 'h1/read:0', 'b5:0', 'lstm_fused_cell/concat_1/axis:0', 'lstm_fused_cell/range/start:0', 'b1:0', 'b2/read:0', 'lstm_fused_cell/Const:0', 'h5:0', 'zeros/Const:0', 'lstm_fused_cell/SequenceMask/Const_2:0', 'b3/read:0', 'lstm_fused_cell/zeros/shape_as_tensor:0', 'zeros:0', 'lstm_fused_cell/range_1/limit:0', 'Minimum_1/y:0', 'raw_logits/shape:0', 'lstm_fused_cell/bias/read:0', 'lstm_fused_cell/zeros:0', 'lstm_fused_cell/Tile/multiples:0', 'Reshape_1/shape:0', 'lstm_fused_cell/zeros/Const:0', 'h1:0', 'b5/read:0', 'lstm_fused_cell/range:0', 'h3:0', 'lstm_fused_cell/range_1/start:0', 'b6/read:0', 'b6:0', 'h6/read:0', 'Reshape_2/shape:0', 'b2:0', 'lstm_fused_cell/range/limit:0', 'lstm_fused_cell/bias:0', 'lstm_fused_cell/transpose/perm:0', 'h3/read:0', 'h2/read:0', 'lstm_fused_cell/SequenceMask/ExpandDims/dim:0', 'b3:0', 'lstm_fused_cell/range_1:0', 'Minimum_3/y:0', 'lstm_fused_cell/concat/axis:0', 'transpose/perm:0']\n",
      "25 nodes deleted\n",
      "0 nodes deleted\n",
      "0 nodes deleted\n",
      "2 identity nodes deleted\n",
      "5 disconnected nodes deleted\n",
      "[SSAConverter] Converting function main ...\n",
      "[SSAConverter] [1/76] Converting op type: 'Placeholder', name: 'input_node', output_shape: (1, 16, 19, 26).\n",
      "[SSAConverter] [2/76] Converting op type: 'Placeholder', name: 'input_lengths', output_shape: (1,).\n",
      "[SSAConverter] [3/76] Converting op type: 'get_global', name: 'previous_state_c/read', output_shape: (1, 2048).\n",
      "[SSAConverter] [4/76] Converting op type: 'get_global', name: 'previous_state_h/read', output_shape: (1, 2048).\n",
      "[SSAConverter] [5/76] Converting op type: 'Const', name: 'transpose/perm', output_shape: (4,).\n",
      "[SSAConverter] [6/76] Converting op type: 'Const', name: 'Reshape/shape', output_shape: (2,).\n",
      "[SSAConverter] [7/76] Converting op type: 'Const', name: 'b1/read', output_shape: (2048,).\n",
      "[SSAConverter] [8/76] Converting op type: 'Const', name: 'Minimum/y'.\n",
      "[SSAConverter] [9/76] Converting op type: 'Const', name: 'b2/read', output_shape: (2048,).\n",
      "[SSAConverter] [10/76] Converting op type: 'Const', name: 'Minimum_1/y'.\n",
      "[SSAConverter] [11/76] Converting op type: 'Const', name: 'b3/read', output_shape: (2048,).\n",
      "[SSAConverter] [12/76] Converting op type: 'Const', name: 'Minimum_2/y'.\n",
      "[SSAConverter] [13/76] Converting op type: 'Const', name: 'Reshape_1/shape', output_shape: (3,).\n",
      "[SSAConverter] [14/76] Converting op type: 'Const', name: 'lstm_fused_cell/kernel/read', output_shape: (4096, 8192).\n",
      "[SSAConverter] [15/76] Converting op type: 'Const', name: 'lstm_fused_cell/bias/read', output_shape: (8192,).\n",
      "[SSAConverter] [16/76] Converting op type: 'Const', name: 'lstm_fused_cell/SequenceMask/Range', output_shape: (16,).\n",
      "[SSAConverter] [17/76] Converting op type: 'Const', name: 'lstm_fused_cell/SequenceMask/ExpandDims/dim'.\n",
      "[SSAConverter] [18/76] Converting op type: 'Const', name: 'lstm_fused_cell/transpose/perm', output_shape: (2,).\n",
      "[SSAConverter] [19/76] Converting op type: 'Const', name: 'lstm_fused_cell/ExpandDims/dim', output_shape: (1,).\n",
      "[SSAConverter] [20/76] Converting op type: 'Const', name: 'lstm_fused_cell/Tile/multiples', output_shape: (3,).\n",
      "[SSAConverter] [21/76] Converting op type: 'Const', name: 'lstm_fused_cell/ExpandDims_1/dim', output_shape: (1,).\n",
      "[SSAConverter] [22/76] Converting op type: 'Const', name: 'lstm_fused_cell/concat/axis'.\n",
      "[SSAConverter] [23/76] Converting op type: 'Const', name: 'lstm_fused_cell/ExpandDims_2/dim', output_shape: (1,).\n",
      "[SSAConverter] [24/76] Converting op type: 'Const', name: 'lstm_fused_cell/concat_1/axis'.\n",
      "[SSAConverter] [25/76] Converting op type: 'Const', name: 'lstm_fused_cell/range', output_shape: (1,).\n",
      "[SSAConverter] [26/76] Converting op type: 'Const', name: 'lstm_fused_cell/range_1', output_shape: (1,).\n",
      "[SSAConverter] [27/76] Converting op type: 'Const', name: 'Reshape_2/shape', output_shape: (2,).\n",
      "[SSAConverter] [28/76] Converting op type: 'Const', name: 'b5/read', output_shape: (2048,).\n",
      "[SSAConverter] [29/76] Converting op type: 'Const', name: 'Minimum_3/y'.\n",
      "[SSAConverter] [30/76] Converting op type: 'Const', name: 'b6/read', output_shape: (29,).\n",
      "[SSAConverter] [31/76] Converting op type: 'Const', name: 'raw_logits/shape', output_shape: (3,).\n",
      "[SSAConverter] [32/76] Converting op type: 'Transpose', name: 'transpose', output_shape: (16, 1, 19, 26).\n",
      "[SSAConverter] [33/76] Converting op type: 'ExpandDims', name: 'lstm_fused_cell/SequenceMask/ExpandDims', output_shape: (1, 1).\n",
      "[SSAConverter] [34/76] Converting op type: 'ExpandDims', name: 'lstm_fused_cell/ExpandDims_1', output_shape: (1, 1, 2048).\n",
      "[SSAConverter] [35/76] Converting op type: 'ExpandDims', name: 'lstm_fused_cell/ExpandDims_2', output_shape: (1, 1, 2048).\n",
      "[SSAConverter] [36/76] Converting op type: 'Pack', name: 'lstm_fused_cell/stack', output_shape: (1, 2).\n",
      "[SSAConverter] [37/76] Converting op type: 'Pack', name: 'lstm_fused_cell/stack_1', output_shape: (1, 2).\n",
      "[SSAConverter] [38/76] Converting op type: 'Reshape', name: 'Reshape', output_shape: (16, 494).\n",
      "[SSAConverter] [39/76] Converting op type: 'Cast', name: 'lstm_fused_cell/SequenceMask/Cast', output_shape: (1, 1).\n",
      "[SSAConverter] [40/76] Converting op type: 'MatMul', name: 'MatMul', output_shape: (16, 2048).\n",
      "[SSAConverter] [41/76] Converting op type: 'Less', name: 'lstm_fused_cell/SequenceMask/Less', output_shape: (1, 16).\n",
      "[SSAConverter] [42/76] Converting op type: 'Add', name: 'Add', output_shape: (16, 2048).\n",
      "[SSAConverter] [43/76] Converting op type: 'Cast', name: 'lstm_fused_cell/SequenceMask/Cast_1', output_shape: (1, 16).\n",
      "[SSAConverter] [44/76] Converting op type: 'Relu', name: 'Relu', output_shape: (16, 2048).\n",
      "[SSAConverter] [45/76] Converting op type: 'Transpose', name: 'lstm_fused_cell/transpose', output_shape: (16, 1).\n",
      "[SSAConverter] [46/76] Converting op type: 'Minimum', name: 'Minimum', output_shape: (16, 2048).\n",
      "[SSAConverter] [47/76] Converting op type: 'ExpandDims', name: 'lstm_fused_cell/ExpandDims', output_shape: (16, 1, 1).\n",
      "[SSAConverter] [48/76] Converting op type: 'MatMul', name: 'MatMul_1', output_shape: (16, 2048).\n",
      "[SSAConverter] [49/76] Converting op type: 'Tile', name: 'lstm_fused_cell/Tile', output_shape: (16, 1, 2048).\n",
      "[SSAConverter] [50/76] Converting op type: 'Add', name: 'Add_1', output_shape: (16, 2048).\n",
      "[SSAConverter] [51/76] Converting op type: 'Relu', name: 'Relu_1', output_shape: (16, 2048).\n",
      "[SSAConverter] [52/76] Converting op type: 'Minimum', name: 'Minimum_1', output_shape: (16, 2048).\n",
      "[SSAConverter] [53/76] Converting op type: 'MatMul', name: 'MatMul_2', output_shape: (16, 2048).\n",
      "[SSAConverter] [54/76] Converting op type: 'Add', name: 'Add_2', output_shape: (16, 2048).\n",
      "[SSAConverter] [55/76] Converting op type: 'Relu', name: 'Relu_2', output_shape: (16, 2048).\n",
      "[SSAConverter] [56/76] Converting op type: 'Minimum', name: 'Minimum_2', output_shape: (16, 2048).\n",
      "[SSAConverter] [57/76] Converting op type: 'Reshape', name: 'Reshape_1', output_shape: (16, 1, 2048).\n",
      "[SSAConverter] [58/76] Converting op type: 'LSTMBlock', name: 'lstm_fused_cell/BlockLSTM/LSTMBlock'.\n",
      "[SSAConverter] [59/76] Converting op type: 'get_tuple', name: 'lstm_fused_cell/BlockLSTM/get_tuple', output_shape: (16, 1, 2048).\n",
      "[SSAConverter] [60/76] Converting op type: 'get_tuple', name: 'lstm_fused_cell/BlockLSTM/get_tuple_0', output_shape: (16, 1, 2048).\n",
      "[SSAConverter] [61/76] Converting op type: 'Mul', name: 'lstm_fused_cell/mul', output_shape: (16, 1, 2048).\n",
      "[SSAConverter] [62/76] Converting op type: 'ConcatV2', name: 'lstm_fused_cell/concat', output_shape: (17, 1, 2048).\n",
      "[SSAConverter] [63/76] Converting op type: 'ConcatV2', name: 'lstm_fused_cell/concat_1', output_shape: (17, 1, 2048).\n",
      "[SSAConverter] [64/76] Converting op type: 'Reshape', name: 'Reshape_2', output_shape: (16, 2048).\n",
      "[SSAConverter] [65/76] Converting op type: 'GatherNd', name: 'lstm_fused_cell/GatherNd', output_shape: (1, 2048).\n",
      "[SSAConverter] [66/76] Converting op type: 'GatherNd', name: 'lstm_fused_cell/GatherNd_1', output_shape: (1, 2048).\n",
      "[SSAConverter] [67/76] Converting op type: 'MatMul', name: 'MatMul_3', output_shape: (16, 2048).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[SSAConverter] [68/76] Converting op type: 'set_global', name: 'Assign_2', output_shape: (1, 2048).\n",
      "[SSAConverter] [69/76] Converting op type: 'set_global', name: 'Assign_3', output_shape: (1, 2048).\n",
      "[SSAConverter] [70/76] Converting op type: 'Add', name: 'Add_3', output_shape: (16, 2048).\n",
      "[SSAConverter] [71/76] Converting op type: 'Relu', name: 'Relu_3', output_shape: (16, 2048).\n",
      "[SSAConverter] [72/76] Converting op type: 'Minimum', name: 'Minimum_3', output_shape: (16, 2048).\n",
      "[SSAConverter] [73/76] Converting op type: 'MatMul', name: 'MatMul_4', output_shape: (16, 29).\n",
      "[SSAConverter] [74/76] Converting op type: 'Add', name: 'Add_4', output_shape: (16, 29).\n",
      "[SSAConverter] [75/76] Converting op type: 'Reshape', name: 'raw_logits', output_shape: (16, 1, 29).\n",
      "[SSAConverter] [76/76] Converting op type: 'Softmax', name: 'logits', output_shape: (16, 1, 29).\n",
      "[Core ML Pass] 15 disconnected constants nodes deleted\n"
     ]
    }
   ],
   "source": [
    "# Where the downloaded TensorFlow graph lives and where the Core ML model is saved.\n",
    "tfmodel_path = './models/output_graph.pb'\n",
    "mlmodel_path = './deep_speech.mlmodel'\n",
    "\n",
    "# Shape of every input feature handed to the converter, keyed by feature name.\n",
    "input_shapes = {\n",
    "    'input_node': [1, 16, 19, 26],\n",
    "    'input_lengths': [1],\n",
    "    'previous_state_h__invar__': [1, 2048],\n",
    "    'previous_state_c__invar__': [1, 2048],\n",
    "}\n",
    "\n",
    "# Convert the model and save it to the local directory.\n",
    "model = tfcoreml.convert(\n",
    "    tf_model_path=tfmodel_path,\n",
    "    mlmodel_path=mlmodel_path,\n",
    "    input_name_shape_dict=input_shapes,\n",
    "    output_feature_names=['logits'],\n",
    "    minimum_ios_deployment_target='13',\n",
    ")\n",
    "\n",
    "# To print and inspect the converted Core ML model, uncomment:\n",
    "# from coremltools.models.neural_network.printer import print_network_spec_coding_style\n",
    "# print_network_spec_coding_style(model.get_spec())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate random data as inputs. A fixed-seed generator is used so that the\n",
    "# Core ML vs. TensorFlow comparison at the end of the notebook sees the same\n",
    "# inputs on every Restart & Run All.\n",
    "rng = np.random.RandomState(0)\n",
    "input_node = rng.rand(1, 16, 19, 26)\n",
    "input_lengths = np.array([16], dtype=np.int32)\n",
    "previous_state_c = rng.rand(1, 2048)\n",
    "previous_state_h = rng.rand(1, 2048)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the converted Core ML model on the generated inputs.\n",
    "coreml_inputs = {\n",
    "    'input_node': input_node,\n",
    "    'input_lengths': input_lengths,\n",
    "    'previous_state_h__invar__': previous_state_h,\n",
    "    'previous_state_c__invar__': previous_state_c,\n",
    "}\n",
    "coreml_prediction = model.predict(coreml_inputs)\n",
    "\n",
    "# Keep only the 'logits' feature and hold it as a NumPy array for the later comparison.\n",
    "out = coreml_prediction['logits']\n",
    "output = np.array(out)\n",
    "# print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally, verify that the Core ML predictions are consistent with TensorFlow's output.\n",
    "\n",
    "# NOTE(review): tf.contrib.rnn is presumably imported only for its side effect of\n",
    "# registering the contrib RNN ops the graph uses (the conversion log shows an\n",
    "# LSTMBlock op) -- confirm. A plain module import avoids the namespace pollution\n",
    "# of `from ... import *`; no name from that module is referenced in this notebook.\n",
    "import tensorflow.contrib.rnn  # noqa: F401\n",
    "\n",
    "# Parse the serialized GraphDef of the original TensorFlow model.\n",
    "with open(tfmodel_path, 'rb') as f:\n",
    "    serialized = f.read()\n",
    "original_gdef = tf.compat.v1.GraphDef()\n",
    "original_gdef.ParseFromString(serialized)\n",
    "\n",
    "# Import into a dedicated graph so that re-running this cell does not add\n",
    "# duplicate nodes to the global default graph.\n",
    "tf_graph = tf.Graph()\n",
    "with tf_graph.as_default():\n",
    "    tf.import_graph_def(original_gdef, name=\"\")\n",
    "\n",
    "with tf.Session(graph=tf_graph) as sess:\n",
    "    # The Softmax op is named 'logits' in this graph (see the conversion log and\n",
    "    # output_feature_names above), so the tensor to fetch is 'logits:0'.\n",
    "    logits = tf_graph.get_tensor_by_name('logits:0')\n",
    "    in1 = tf_graph.get_tensor_by_name('input_node:0')\n",
    "    in2 = tf_graph.get_tensor_by_name('input_lengths:0')\n",
    "    in3 = tf_graph.get_tensor_by_name('previous_state_c:0')\n",
    "    in4 = tf_graph.get_tensor_by_name('previous_state_h:0')\n",
    "\n",
    "    # Feed the same arrays that were given to the Core ML model.\n",
    "    tf_out = sess.run(logits, feed_dict={\n",
    "        in1: input_node,\n",
    "        in2: input_lengths,\n",
    "        in3: previous_state_c,\n",
    "        in4: previous_state_h,\n",
    "    })\n",
    "\n",
    "tf_output = np.array(tf_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The Core ML and TensorFlow outputs must have the same shape ...\n",
    "np.testing.assert_array_equal(output.shape, tf_output.shape)\n",
    "\n",
    "# ... and the same values, compared element-wise to 2 decimal places.\n",
    "coreml_flat = output.flatten()\n",
    "tf_flat = tf_output.flatten()\n",
    "np.testing.assert_almost_equal(coreml_flat, tf_flat, decimal=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
